# pyright: reportUnknownArgumentType=false, reportUnknownMemberType=false, reportAssignmentType=false, reportCallIssue=false, reportUnknownVariableType=false
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from secrets import token_hex
from typing import Any, Optional, cast

import pandas as pd
import pytest
from openinference.semconv.trace import DocumentAttributes, SpanAttributes

from phoenix.client import AsyncClient
from phoenix.client import Client as SyncClient
from phoenix.client.__generated__ import v1
from phoenix.client.helpers.spans.rag import (
    async_get_input_output_context,
    async_get_retrieved_documents,
    get_input_output_context,
    get_retrieved_documents,
)

from .._helpers import (  # pyright: ignore[reportPrivateUsage]
    _AppInfo,
    _ExistingProject,
    _until_spans_exist,
)

# Aliases for common OpenInference attribute keys used in rag helper queries
DOCUMENT_CONTENT = DocumentAttributes.DOCUMENT_CONTENT
DOCUMENT_SCORE = DocumentAttributes.DOCUMENT_SCORE
DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
INPUT_VALUE = SpanAttributes.INPUT_VALUE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
METADATA = SpanAttributes.METADATA


def _doc(
    content: str, score: float | None = None, metadata: dict[str, Any] | None = None
) -> dict[str, Any]:
    d: dict[str, Any] = {"content": content}
    if score is not None:
        d["score"] = score
    if metadata is not None:
        d["metadata"] = metadata
    return d


def _create_root_span(
    *,
    trace_id: str,
    span_id: str,
    name: str = "root",
    start: datetime | None = None,
    duration_secs: float = 1.0,
    input_text: str = "question",
    output_text: str = "answer",
    extra_metadata: dict[str, Any] | None = None,
) -> v1.Span:
    start_time = (start or datetime.now(timezone.utc)).isoformat()
    end_time = (datetime.fromisoformat(start_time) + timedelta(seconds=duration_secs)).isoformat()
    attrs: dict[str, Any] = {
        INPUT_VALUE: input_text,
        OUTPUT_VALUE: output_text,
    }
    if extra_metadata is not None:
        attrs[METADATA] = extra_metadata
    return cast(
        v1.Span,
        {
            "name": name,
            "context": {"trace_id": trace_id, "span_id": span_id},
            "span_kind": "CHAIN",
            "parent_id": None,
            "start_time": start_time,
            "end_time": end_time,
            "status_code": "OK",
            "attributes": attrs,
        },
    )


def _create_retriever_span(
    *,
    trace_id: str,
    span_id: str,
    parent_id: Optional[str] = None,
    name: str = "retriever",
    start: datetime | None = None,
    duration_secs: float = 1.0,
    input_text: str = "retrieve for question",
    documents: list[dict[str, Any]] | None = None,
) -> v1.Span:
    start_time = (start or datetime.now(timezone.utc)).isoformat()
    end_time = (datetime.fromisoformat(start_time) + timedelta(seconds=duration_secs)).isoformat()
    docs = documents if documents is not None else []
    return cast(
        v1.Span,
        {
            "name": name,
            "context": {"trace_id": trace_id, "span_id": span_id},
            "parent_id": parent_id,
            "span_kind": "RETRIEVER",
            "start_time": start_time,
            "end_time": end_time,
            "status_code": "OK",
            "attributes": {
                "input": {"value": input_text},
                "retrieval": {"documents": [{"document": doc} for doc in docs]},
            },
        },
    )


class TestEvaluationHelpersRag:
    @pytest.mark.parametrize("is_async", [True, False])
    async def test_no_matching_spans(
        self,
        is_async: bool,
        _existing_project: _ExistingProject,
        _app: _AppInfo,
    ) -> None:
        api_key = _app.admin_secret

        project_name = _existing_project.name

        if is_async:
            client_async = AsyncClient(base_url=_app.base_url, api_key=api_key)
            client_sync = None
        else:
            client_sync = SyncClient(base_url=_app.base_url, api_key=api_key)
            client_async = None

        # No spans logged
        docs_df: pd.DataFrame
        if is_async:
            assert client_async is not None
            docs_df = await async_get_retrieved_documents(client_async, project_name=project_name)
        else:
            assert client_sync is not None
            docs_df = get_retrieved_documents(client_sync, project_name=project_name)
        assert isinstance(docs_df, pd.DataFrame)
        assert docs_df.empty

        qa_df: Optional[pd.DataFrame]
        if is_async:
            assert client_async is not None
            qa_df = await async_get_input_output_context(client_async, project_name=project_name)
        else:
            assert client_sync is not None
            qa_df = get_input_output_context(client_sync, project_name=project_name)
        assert qa_df is None

    @pytest.mark.parametrize("is_async", [True, False])
    async def test_retrieved_documents_basic_and_edge_cases(
        self,
        is_async: bool,
        _existing_project: _ExistingProject,
        _app: _AppInfo,
    ) -> None:
        """Covers basic explosion to rows, missing fields, and empty docs."""
        api_key = _app.admin_secret

        if is_async:
            client_async = AsyncClient(base_url=_app.base_url, api_key=api_key)
            client_sync = None
        else:
            client_sync = SyncClient(base_url=_app.base_url, api_key=api_key)
            client_async = None
        project_name = _existing_project.name

        trace_id = f"trace_retrieve_{token_hex(8)}"
        retriever_span_id = f"retr_{token_hex(8)}"
        # One doc with full fields, one doc missing score/metadata
        docs = [
            _doc("doc_1_content", score=0.9, metadata={"source": "a"}),
            _doc("doc_2_content"),
        ]
        retriever = _create_retriever_span(
            trace_id=trace_id,
            span_id=retriever_span_id,
            input_text="what is X?",
            documents=docs,
        )

        # Another retriever with empty docs -> contributes no rows
        empty_retriever_span_id = f"retr_empty_{token_hex(6)}"
        empty_retriever = _create_retriever_span(
            trace_id=trace_id,
            span_id=empty_retriever_span_id,
            input_text="empty case",
            documents=[],
        )

        # Log spans
        if is_async:
            assert client_async is not None
            create_result = await client_async.spans.log_spans(  # pyright: ignore[reportAttributeAccessIssue]
                project_identifier=project_name, spans=[retriever, empty_retriever]
            )
        else:
            assert client_sync is not None
            create_result = client_sync.spans.log_spans(  # pyright: ignore[reportAttributeAccessIssue]
                project_identifier=project_name, spans=[retriever, empty_retriever]
            )
        assert create_result["total_queued"] == 2
        await _until_spans_exist(_app, [retriever_span_id])

        df: pd.DataFrame
        if is_async:
            assert client_async is not None
            df = await async_get_retrieved_documents(client_async, project_name=project_name)
        else:
            assert client_sync is not None
            df = get_retrieved_documents(client_sync, project_name=project_name)

        assert isinstance(df, pd.DataFrame)
        # Focus only on rows for the retriever that has docs
        df_docs_only = df[df.index.get_level_values(0) == retriever_span_id]  # pyright: ignore[reportUnknownVariableType]
        assert len(df_docs_only) == 2
        # Expect multi-index with span_id and document position
        assert df_docs_only.index.nlevels == 2
        assert "context.trace_id" in df_docs_only.columns
        assert "input" in df_docs_only.columns
        # Input propagated from retriever span
        assert all(val == "what is X?" for val in df_docs_only["input"].tolist())  # pyright: ignore[reportUnknownVariableType]
        # Content and score/metadata assertions when available
        if "document" in df_docs_only.columns:
            documents = set(df_docs_only["document"].astype(str).tolist())  # pyright: ignore[reportAttributeAccessIssue,reportUnknownVariableType]
            assert "doc_1_content" in documents and "doc_2_content" in documents
        if "document_score" in df_docs_only.columns:
            has_missing = any(pd.isna(s) for s in df_docs_only["document_score"].tolist())  # pyright: ignore[reportArgumentType,reportUnknownVariableType]
            assert has_missing

    @pytest.mark.parametrize("is_async", [True, False])
    async def test_input_output_context_concatenation(
        self,
        is_async: bool,
        _existing_project: _ExistingProject,
        _app: _AppInfo,
    ) -> None:
        """Ensure concatenation across retriever spans/documents is correct."""
        api_key = _app.admin_secret

        if is_async:
            client_async = AsyncClient(base_url=_app.base_url, api_key=api_key)
            client_sync = None
        else:
            client_sync = SyncClient(base_url=_app.base_url, api_key=api_key)
            client_async = None
        project_name = _existing_project.name

        trace_id = f"trace_concat_{token_hex(8)}"
        root_span_id = f"root_{token_hex(8)}"
        retr1_id = f"retr1_{token_hex(6)}"
        retr2_id = f"retr2_{token_hex(6)}"
        root = _create_root_span(
            trace_id=trace_id,
            span_id=root_span_id,
            input_text="What is the capital of France?",
            output_text="Paris",
            extra_metadata={"task": "geography"},
        )
        retr1 = _create_retriever_span(
            trace_id=trace_id,
            span_id=retr1_id,
            parent_id=root_span_id,
            documents=[_doc("Paris is the capital city of France.", score=0.95)],
        )
        retr2 = _create_retriever_span(
            trace_id=trace_id,
            span_id=retr2_id,
            parent_id=root_span_id,
            documents=[_doc("France is a country in Western Europe.", score=0.8)],
        )

        if is_async:
            assert client_async is not None
            create_result = await client_async.spans.log_spans(  # pyright: ignore[reportAttributeAccessIssue]
                project_identifier=project_name, spans=[root, retr1, retr2]
            )
        else:
            assert client_sync is not None
            create_result = client_sync.spans.log_spans(  # pyright: ignore[reportAttributeAccessIssue]
                project_identifier=project_name, spans=[root, retr1, retr2]
            )
        assert create_result["total_queued"] == 3
        await _until_spans_exist(_app, [root_span_id, retr1_id, retr2_id])
        # Poll to account for eventual consistency of dataframe endpoints
        if is_async:
            assert client_async is not None
            qa_df2 = await async_get_input_output_context(
                client_async, project_name=project_name, timeout=15
            )
        else:
            assert client_sync is not None
            qa_df2 = get_input_output_context(client_sync, project_name=project_name, timeout=15)
        assert qa_df2 is not None
        assert isinstance(qa_df2, pd.DataFrame)
        # Index should be context.span_id, which should include the root span
        assert str(root_span_id) in qa_df2.index.astype(str)  # pyright: ignore[reportGeneralTypeIssues]
        row = qa_df2.loc[str(root_span_id)]
        assert row["input"] == "What is the capital of France?"  # pyright: ignore[reportGeneralTypeIssues]
        assert row["output"] == "Paris"  # pyright: ignore[reportGeneralTypeIssues]
        # Confirm concatenation contains both pieces and uses separator "\n\n"
        context_val = row["context"]
        assert isinstance(context_val, str)
        assert "Paris is the capital city of France." in context_val
        assert "France is a country in Western Europe." in context_val
        assert "\n\n" in context_val
        # Metadata is propagated
        assert (
            "metadata" in row  # pyright: ignore[reportGeneralTypeIssues]
            and isinstance(row["metadata"], dict)
            and row["metadata"]["task"] == "geography"
        )

    @pytest.mark.parametrize("is_async", [True, False])
    async def test_time_filtering_helpers(
        self,
        is_async: bool,
        _existing_project: _ExistingProject,
        _app: _AppInfo,
    ) -> None:
        """Verify start_time/end_time filter behavior for both helpers."""
        api_key = _app.admin_secret

        if is_async:
            client_async = AsyncClient(base_url=_app.base_url, api_key=api_key)
            client_sync = None
        else:
            client_sync = SyncClient(base_url=_app.base_url, api_key=api_key)
            client_async = None
        project_name = _existing_project.name

        base_time = datetime.now(timezone.utc)

        # Early trace (should be excluded by later filter)
        trace_early = f"trace_time_{token_hex(8)}"
        retr_early_id = f"retr_{token_hex(6)}"
        root_early = _create_root_span(
            trace_id=trace_early,
            span_id=f"root_{token_hex(6)}",
            start=base_time,
            input_text="early in time",
            output_text="early out",
        )
        retr_early = _create_retriever_span(
            trace_id=trace_early,
            span_id=retr_early_id,
            parent_id=root_early["context"]["span_id"],
            start=base_time,
            documents=[_doc("early doc")],
        )

        # Later trace (should be included)
        later_start = base_time + timedelta(seconds=30)
        trace_late = f"trace_time_{token_hex(8)}"
        retr_late_id = f"retr_{token_hex(6)}"
        root_late = _create_root_span(
            trace_id=trace_late,
            span_id=f"root_{token_hex(6)}",
            start=later_start,
            input_text="late in time",
            output_text="late out",
        )
        retr_late = _create_retriever_span(
            trace_id=trace_late,
            span_id=retr_late_id,
            parent_id=root_late["context"]["span_id"],
            start=later_start,
            documents=[_doc("late doc")],
        )

        if is_async:
            assert client_async is not None
            create_result = await client_async.spans.log_spans(  # pyright: ignore[reportAttributeAccessIssue]
                project_identifier=project_name,
                spans=[root_early, retr_early, root_late, retr_late],
            )
        else:
            assert client_sync is not None
            create_result = client_sync.spans.log_spans(  # pyright: ignore[reportAttributeAccessIssue]
                project_identifier=project_name,
                spans=[root_early, retr_early, root_late, retr_late],
            )
        assert create_result["total_queued"] == 4
        await _until_spans_exist(
            _app,
            [
                root_early["context"]["span_id"],
                retr_early_id,
                root_late["context"]["span_id"],
                retr_late_id,
            ],
        )

        # Filter to include only the later spans
        start_time = later_start - timedelta(seconds=1)
        end_time = later_start + timedelta(seconds=10)

        docs_df: pd.DataFrame
        if is_async:
            assert client_async is not None
            docs_df = await async_get_retrieved_documents(
                client_async,
                project_name=project_name,
                start_time=start_time,
                end_time=end_time,
                timeout=15,
            )
        else:
            assert client_sync is not None
            docs_df = get_retrieved_documents(
                client_sync,
                project_name=project_name,
                start_time=start_time,
                end_time=end_time,
                timeout=15,
            )
        assert isinstance(docs_df, pd.DataFrame)
        # Should only include "late doc"
        if "document" in docs_df.columns:
            documents: set[str] = (
                set(docs_df["document"].astype(str).tolist()) if not docs_df.empty else set()  # pyright: ignore[reportAttributeAccessIssue]
            )
            assert documents == {"late doc"}
        else:
            assert not docs_df.empty
            assert set(docs_df["context.trace_id"].astype(str).tolist()) == {trace_late}  # pyright: ignore[reportAttributeAccessIssue]
            assert docs_df.shape[0] == 1

        if is_async:
            assert client_async is not None
            qa_df = await async_get_input_output_context(
                client_async,
                project_name=project_name,
                start_time=start_time,
                end_time=end_time,
                timeout=15,
            )
        else:
            assert client_sync is not None
            qa_df = get_input_output_context(
                client_sync,
                project_name=project_name,
                start_time=start_time,
                end_time=end_time,
                timeout=15,
            )
        assert qa_df is not None
        assert isinstance(qa_df, pd.DataFrame)
        # Only the late root span should be present
        assert len(qa_df) == 1
        row = qa_df.iloc[0]
        assert row["input"] == "late in time"  # pyright: ignore[reportGeneralTypeIssues]
        assert row["output"] == "late out"  # pyright: ignore[reportGeneralTypeIssues]
        assert "late doc" in row["context"]
